from torch.utils.data import Dataset
import torch
import pickle

class MyDataset(Dataset):
    """
    自定义数据集类，继承自Dataset
    重写__init__, __len__, __getitem__的魔法方法
    """
    def __init__(self, qa_ids_list, max_len):
        """
        初始化函数，用于设置数据集属性
        :param qa_ids_list: 从pkl文件中加载出来的id问答列表
        :param max_len:最大长度
        """
        super().__init__()
        self.qa_ids_list = qa_ids_list
        self.max_len = max_len

    def __len__(self):
        """
        返回数据集长度
        :return:数据集的长度
        """
        return len(self.qa_ids_list)

    def __getitem__(self, index):
        """
        根据指定索引返回数据集的对应的一个样本
        :param index:样本的索引
        :return:索引对应的问题id的张量
        """
        qa_ids = self.qa_ids_list[index]
        qa_ids = qa_ids[:self.max_len]
        qa_ids = torch.tensor(qa_ids, dtype=torch.long)
        return qa_ids

